{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this notebook, we'll learn how to use GANs to do semi-supervised learning.\n",
    "\n",
    "In supervised learning, we have a training set of inputs $x$ and class labels $y$. We train a model that takes $x$ as input and gives $y$ as output.\n",
    "\n",
    "In semi-supervised learning, our goal is still to train a model that takes $x$ as input and produces $y$ as output. However, not all of our training examples have a label $y$. We need an algorithm that gets better at classification by studying both labeled $(x, y)$ pairs and unlabeled $x$ examples.\n",
    "\n",
    "To do this for the SVHN dataset, we'll turn the GAN discriminator into an 11-class discriminator. It will recognize the 10 different classes of real SVHN digits, as well as an 11th class of fake images that come from the generator. The discriminator will get to train on real labeled images, real unlabeled images, and fake images. Because it draws on three sources of data instead of just one, it should generalize to the test set much better than a traditional classifier trained only on the labeled images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "import pickle as pkl\n",
    "import time\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "from scipy.io import loadmat\n",
    "import tensorflow as tf\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!mkdir data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from urllib.request import urlretrieve\n",
    "from os.path import isfile, isdir\n",
    "from tqdm import tqdm\n",
    "\n",
    "data_dir = 'data/'\n",
    "\n",
    "if not isdir(data_dir):\n",
    "    raise Exception(\"Data directory doesn't exist!\")\n",
    "\n",
    "class DLProgress(tqdm):\n",
    "    last_block = 0\n",
    "\n",
    "    def hook(self, block_num=1, block_size=1, total_size=None):\n",
    "        self.total = total_size\n",
    "        self.update((block_num - self.last_block) * block_size)\n",
    "        self.last_block = block_num\n",
    "\n",
    "if not isfile(data_dir + \"train_32x32.mat\"):\n",
    "    with DLProgress(unit='B', unit_scale=True, miniters=1, desc='SVHN Training Set') as pbar:\n",
    "        urlretrieve(\n",
    "            'http://ufldl.stanford.edu/housenumbers/train_32x32.mat',\n",
    "            data_dir + 'train_32x32.mat',\n",
    "            pbar.hook)\n",
    "\n",
    "if not isfile(data_dir + \"test_32x32.mat\"):\n",
    "    with DLProgress(unit='B', unit_scale=True, miniters=1, desc='SVHN Testing Set') as pbar:\n",
    "        urlretrieve(\n",
    "            'http://ufldl.stanford.edu/housenumbers/test_32x32.mat',\n",
    "            data_dir + 'test_32x32.mat',\n",
    "            pbar.hook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "trainset = loadmat(data_dir + 'train_32x32.mat')\n",
    "testset = loadmat(data_dir + 'test_32x32.mat')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "idx = np.random.randint(0, trainset['X'].shape[3], size=36)\n",
    "fig, axes = plt.subplots(6, 6, sharex=True, sharey=True, figsize=(5,5),)\n",
    "for ii, ax in zip(idx, axes.flatten()):\n",
    "    ax.imshow(trainset['X'][:,:,:,ii], aspect='equal')\n",
    "    ax.xaxis.set_visible(False)\n",
    "    ax.yaxis.set_visible(False)\n",
    "plt.subplots_adjust(wspace=0, hspace=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def scale(x, feature_range=(-1, 1)):\n",
    "    # scale to [0, 1]\n",
    "    x = ((x - x.min())/(255 - x.min()))\n",
    "    \n",
    "    # scale to feature_range (named so as not to shadow the built-ins min/max)\n",
    "    range_min, range_max = feature_range\n",
    "    x = x * (range_max - range_min) + range_min\n",
    "    return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class Dataset:\n",
    "    def __init__(self, train, test, val_frac=0.5, shuffle=True, scale_func=None):\n",
    "        split_idx = int(len(test['y'])*(1 - val_frac))\n",
    "        self.test_x, self.valid_x = test['X'][:,:,:,:split_idx], test['X'][:,:,:,split_idx:]\n",
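    "        self.test_y, self.valid_y = test['y'][:split_idx], test['y'][split_idx:]\n",
    "        self.train_x, self.train_y = train['X'], train['y']\n",
    "        # SVHN stores the digit 0 with label 10; remap it to 0 so labels lie in [0, 9]\n",
    "        self.train_y = np.where(self.train_y == 10, 0, self.train_y)\n",
    "        self.valid_y = np.where(self.valid_y == 10, 0, self.valid_y)\n",
    "        self.test_y = np.where(self.test_y == 10, 0, self.test_y)\n",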
    "        # The SVHN dataset comes with lots of labels, but for the purpose of this exercise,\n",
    "        # we will pretend that there are only 1000.\n",
    "        # We use this mask to say which labels we will allow ourselves to use.\n",
    "        self.label_mask = np.zeros_like(self.train_y)\n",
    "        self.label_mask[0:1000] = 1\n",
    "        \n",
    "        self.train_x = np.rollaxis(self.train_x, 3)\n",
    "        self.valid_x = np.rollaxis(self.valid_x, 3)\n",
    "        self.test_x = np.rollaxis(self.test_x, 3)\n",
    "        \n",
    "        if scale_func is None:\n",
    "            self.scaler = scale\n",
    "        else:\n",
    "            self.scaler = scale_func\n",
    "        self.train_x = self.scaler(self.train_x)\n",
    "        self.valid_x = self.scaler(self.valid_x)\n",
    "        self.test_x = self.scaler(self.test_x)\n",
    "        self.shuffle = shuffle\n",
    "        \n",
    "    def batches(self, batch_size, which_set=\"train\"):\n",
    "        x_name = which_set + \"_x\"\n",
    "        y_name = which_set + \"_y\"\n",
    "        \n",
    "        num_examples = len(getattr(self, y_name))\n",
    "        if self.shuffle:\n",
    "            idx = np.arange(num_examples)\n",
    "            np.random.shuffle(idx)\n",
    "            setattr(self, x_name, getattr(self, x_name)[idx])\n",
    "            setattr(self, y_name, getattr(self, y_name)[idx])\n",
    "            if which_set == \"train\":\n",
    "                self.label_mask = self.label_mask[idx]\n",
    "        \n",
    "        dataset_x = getattr(self, x_name)\n",
    "        dataset_y = getattr(self, y_name)\n",
    "        for ii in range(0, num_examples, batch_size):\n",
    "            x = dataset_x[ii:ii+batch_size]\n",
    "            y = dataset_y[ii:ii+batch_size]\n",
    "            \n",
    "            if which_set == \"train\":\n",
    "                # When we use the data for training, we need to include\n",
    "                # the label mask, so we can pretend we don't have access\n",
    "                # to some of the labels, as an exercise of our semi-supervised\n",
    "                # learning ability\n",
    "                yield x, y, self.label_mask[ii:ii+batch_size]\n",
    "            else:\n",
    "                yield x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def model_inputs(real_dim, z_dim):\n",
    "    inputs_real = tf.placeholder(tf.float32, (None, *real_dim), name='input_real')\n",
    "    inputs_z = tf.placeholder(tf.float32, (None, z_dim), name='input_z')\n",
    "    y = tf.placeholder(tf.int32, (None), name='y')\n",
    "    label_mask = tf.placeholder(tf.int32, (None), name='label_mask')\n",
    "    \n",
    "    return inputs_real, inputs_z, y, label_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def generator(z, output_dim, reuse=False, alpha=0.2, training=True, size_mult=128):\n",
    "    with tf.variable_scope('generator', reuse=reuse):\n",
    "        # First fully connected layer\n",
    "        x1 = tf.layers.dense(z, 4 * 4 * size_mult * 4)\n",
    "        # Reshape it to start the convolutional stack\n",
    "        x1 = tf.reshape(x1, (-1, 4, 4, size_mult * 4))\n",
    "        x1 = tf.layers.batch_normalization(x1, training=training)\n",
    "        x1 = tf.maximum(alpha * x1, x1)\n",
    "        \n",
    "        x2 = tf.layers.conv2d_transpose(x1, size_mult * 2, 5, strides=2, padding='same')\n",
    "        x2 = tf.layers.batch_normalization(x2, training=training)\n",
    "        x2 = tf.maximum(alpha * x2, x2)\n",
    "        \n",
    "        x3 = tf.layers.conv2d_transpose(x2, size_mult, 5, strides=2, padding='same')\n",
    "        x3 = tf.layers.batch_normalization(x3, training=training)\n",
    "        x3 = tf.maximum(alpha * x3, x3)\n",
    "        \n",
    "        # Output layer\n",
    "        logits = tf.layers.conv2d_transpose(x3, output_dim, 5, strides=2, padding='same')\n",
    "        \n",
    "        out = tf.tanh(logits)\n",
    "        \n",
    "        return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def discriminator(x, reuse=False, alpha=0.2, drop_rate=0., num_classes=10, size_mult=64):\n",
    "    with tf.variable_scope('discriminator', reuse=reuse):\n",
    "        # tf.layers.dropout is a no-op unless training=True; drop_rate is fed as 0 at test time\n",
    "        x = tf.layers.dropout(x, rate=drop_rate/2.5, training=True)\n",
    "        \n",
    "        # Input layer is 32x32x3\n",
    "        x1 = tf.layers.conv2d(x, size_mult, 3, strides=2, padding='same')\n",
    "        relu1 = tf.maximum(alpha * x1, x1)\n",
    "        relu1 = tf.layers.dropout(relu1, rate=drop_rate, training=True)\n",
    "        \n",
    "        x2 = tf.layers.conv2d(relu1, size_mult, 3, strides=2, padding='same')\n",
    "        bn2 = tf.layers.batch_normalization(x2, training=True)\n",
    "        relu2 = tf.maximum(alpha * bn2, bn2)\n",
    "        \n",
    "        x3 = tf.layers.conv2d(relu2, size_mult, 3, strides=2, padding='same')\n",
    "        bn3 = tf.layers.batch_normalization(x3, training=True)\n",
    "        relu3 = tf.maximum(alpha * bn3, bn3)\n",
    "        relu3 = tf.layers.dropout(relu3, rate=drop_rate, training=True)\n",
    "        \n",
    "        x4 = tf.layers.conv2d(relu3, 2 * size_mult, 3, strides=1, padding='same')\n",
    "        bn4 = tf.layers.batch_normalization(x4, training=True)\n",
    "        relu4 = tf.maximum(alpha * bn4, bn4)\n",
    "        \n",
    "        x5 = tf.layers.conv2d(relu4, 2 * size_mult, 3, strides=1, padding='same')\n",
    "        bn5 = tf.layers.batch_normalization(x5, training=True)\n",
    "        relu5 = tf.maximum(alpha * bn5, bn5)\n",
    "        \n",
    "        x6 = tf.layers.conv2d(relu5, 2 * size_mult, 3, strides=1, padding='valid')\n",
    "        # Don't use bn on this layer, because bn would set the mean of each feature\n",
    "        # to the bn mu parameter.\n",
    "        # This layer is used for the feature matching loss, which only works if\n",
    "        # the means can be different when the discriminator is run on the data than\n",
    "        # when the discriminator is run on the generator samples.\n",
    "        relu6 = tf.maximum(alpha * x6, x6)\n",
    "        \n",
    "        # Flatten it by global average pooling\n",
    "        raise NotImplementedError()\n",
    "        \n",
    "        # Set class_logits to be the inputs to a softmax distribution over the different classes\n",
    "        raise NotImplementedError()\n",
    "        \n",
    "        \n",
    "        # Set gan_logits such that P(input is real | input) = sigmoid(gan_logits).\n",
    "        # Keep in mind that class_logits gives you the probability distribution over all the real\n",
    "        # classes and the fake class. You need to work out how to transform this multiclass softmax\n",
    "        # distribution into a binary real-vs-fake decision that can be described with a sigmoid.\n",
    "        # Numerical stability is very important.\n",
    "        # You'll probably need to use this numerical stability trick:\n",
    "        # log sum_i exp a_i = m + log sum_i exp(a_i - m).\n",
    "        # This is numerically stable when m = max_i a_i.\n",
    "        # (It helps to think about what goes wrong when...\n",
    "        #   1. One value of a_i is very large\n",
    "        #   2. All the values of a_i are very negative\n",
    "        # This trick and this value of m fix both those cases, but the naive implementation and\n",
    "        # other values of m encounter various problems)\n",
    "        \n",
    "        raise NotImplementedError()\n",
    "        \n",
    "        return out, class_logits, gan_logits, features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def model_loss(input_real, input_z, output_dim, y, num_classes, label_mask, alpha=0.2, drop_rate=0.):\n",
    "    \"\"\"\n",
    "    Get the loss for the discriminator and generator\n",
    "    :param input_real: Images from the real dataset\n",
    "    :param input_z: Z input\n",
    "    :param output_dim: The number of channels in the output image\n",
    "    :param y: Integer class labels\n",
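    "    :param num_classes: The number of classes\n",
    "    :param label_mask: Mask with 1 for examples whose labels may be used and 0 for examples treated as unlabeled\n",
    "    :param alpha: The slope of the left half of leaky ReLU activation\n",
    "    :param drop_rate: The probability of dropping a hidden unit\n",
    "    :return: A tuple of (discriminator loss, generator loss, number of correct predictions,\n",
    "             number of correct predictions on labeled examples, generator samples)\n",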
    "    \"\"\"\n",
    "    \n",
    "    \n",
    "    # These numbers multiply the size of each layer of the generator and the discriminator,\n",
    "    # respectively. You can reduce them to run your code faster for debugging purposes.\n",
    "    g_size_mult = 32\n",
    "    d_size_mult = 64\n",
    "    \n",
    "    # Here we run the generator and the discriminator\n",
    "    g_model = generator(input_z, output_dim, alpha=alpha, size_mult=g_size_mult)\n",
    "    d_on_data = discriminator(input_real, alpha=alpha, drop_rate=drop_rate, size_mult=d_size_mult)\n",
    "    d_model_real, class_logits_on_data, gan_logits_on_data, data_features = d_on_data\n",
    "    d_on_samples = discriminator(g_model, reuse=True, alpha=alpha, drop_rate=drop_rate, size_mult=d_size_mult)\n",
    "    d_model_fake, class_logits_on_samples, gan_logits_on_samples, sample_features = d_on_samples\n",
    "    \n",
    "    \n",
    "    # Here we compute `d_loss`, the loss for the discriminator.\n",
    "    # This should combine two different losses:\n",
    "    #  1. The loss for the GAN problem, where we minimize the cross-entropy for the binary\n",
    "    #     real-vs-fake classification problem.\n",
    "    #  2. The loss for the SVHN digit classification problem, where we minimize the cross-entropy\n",
    "    #     for the multi-class softmax. For this one we use the labels. Don't forget to use\n",
    "    #     `label_mask` to ignore the examples that we are pretending are unlabeled for the\n",
    "    #     semi-supervised learning problem.\n",
    "    raise NotImplementedError()\n",
    "    \n",
    "    # Here we set `g_loss` to the \"feature matching\" loss invented by Tim Salimans at OpenAI.\n",
    "    # This loss consists of minimizing the absolute difference between the expected features\n",
    "    # on the data and the expected features on the generated samples.\n",
    "    # This loss works better for semi-supervised learning than the traditional GAN losses.\n",
    "    raise NotImplementedError()\n",
    "\n",
    "    pred_class = tf.cast(tf.argmax(class_logits_on_data, 1), tf.int32)\n",
    "    eq = tf.equal(tf.squeeze(y), pred_class)\n",
    "    correct = tf.reduce_sum(tf.to_float(eq))\n",
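    "    # Cast and squeeze the int32 mask so it can multiply the float comparison results\n",
    "    masked_correct = tf.reduce_sum(tf.to_float(tf.squeeze(label_mask)) * tf.to_float(eq))\n",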
    "    \n",
    "    return d_loss, g_loss, correct, masked_correct, g_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def model_opt(d_loss, g_loss, learning_rate, beta1):\n",
    "    \"\"\"\n",
    "    Get optimization operations\n",
    "    :param d_loss: Discriminator loss Tensor\n",
    "    :param g_loss: Generator loss Tensor\n",
    "    :param learning_rate: Learning rate Variable (decayed in place by the returned `shrink_lr` op)\n",
    "    :param beta1: The exponential decay rate for the 1st moment in the optimizer\n",
    "    :return: A tuple of (discriminator training operation, generator training operation,\n",
    "             learning-rate decay operation)\n",
    "    \"\"\"\n",
    "    # Get weights and biases to update. Get them separately for the discriminator and the generator\n",
    "    raise NotImplementedError()\n",
    "\n",
    "    # Minimize both players' costs simultaneously\n",
    "    raise NotImplementedError()\n",
    "    shrink_lr = tf.assign(learning_rate, learning_rate * 0.9)\n",
    "    \n",
    "    return d_train_opt, g_train_opt, shrink_lr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class GAN:\n",
    "    \"\"\"\n",
    "    A GAN model.\n",
    "    :param real_size: The shape of the real data.\n",
    "    :param z_size: The number of entries in the z code vector.\n",
    "    :param learning_rate: The learning rate to use for Adam.\n",
    "    :param num_classes: The number of classes to recognize.\n",
    "    :param alpha: The slope of the left half of the leaky ReLU activation\n",
    "    :param beta1: The beta1 parameter for Adam.\n",
    "    \"\"\"\n",
    "    def __init__(self, real_size, z_size, learning_rate, num_classes=10, alpha=0.2, beta1=0.5):\n",
    "        tf.reset_default_graph()\n",
    "        \n",
    "        self.learning_rate = tf.Variable(learning_rate, trainable=False)\n",
    "        inputs = model_inputs(real_size, z_size)\n",
    "        self.input_real, self.input_z, self.y, self.label_mask = inputs\n",
    "        self.drop_rate = tf.placeholder_with_default(.5, (), \"drop_rate\")\n",
    "        \n",
    "        loss_results = model_loss(self.input_real, self.input_z,\n",
    "                                  real_size[2], self.y, num_classes,\n",
    "                                  label_mask=self.label_mask,\n",
    "                                  alpha=alpha,\n",
    "                                  drop_rate=self.drop_rate)\n",
    "        self.d_loss, self.g_loss, self.correct, self.masked_correct, self.samples = loss_results\n",
    "        \n",
    "        self.d_opt, self.g_opt, self.shrink_lr = model_opt(self.d_loss, self.g_loss, self.learning_rate, beta1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def view_samples(epoch, samples, nrows, ncols, figsize=(5,5)):\n",
    "    fig, axes = plt.subplots(figsize=figsize, nrows=nrows, ncols=ncols, \n",
    "                             sharey=True, sharex=True)\n",
    "    for ax, img in zip(axes.flatten(), samples[epoch]):\n",
    "        ax.axis('off')\n",
    "        img = ((img - img.min())*255 / (img.max() - img.min())).astype(np.uint8)\n",
    "        ax.set_adjustable('box-forced')\n",
    "        im = ax.imshow(img)\n",
    "   \n",
    "    plt.subplots_adjust(wspace=0, hspace=0)\n",
    "    return fig, axes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def train(net, dataset, epochs, batch_size, figsize=(5,5)):\n",
    "    \n",
    "    saver = tf.train.Saver()\n",
    "    sample_z = np.random.normal(0, 1, size=(50, z_size))\n",
    "\n",
    "    samples, train_accuracies, test_accuracies = [], [], []\n",
    "    steps = 0\n",
    "\n",
    "    with tf.Session() as sess:\n",
    "        sess.run(tf.global_variables_initializer())\n",
    "        for e in range(epochs):\n",
    "            print(\"Epoch\",e)\n",
    "            \n",
    "            t1e = time.time()\n",
    "            num_examples = 0\n",
    "            num_correct = 0\n",
    "            for x, y, label_mask in dataset.batches(batch_size):\n",
    "                assert 'int' in str(y.dtype)\n",
    "                steps += 1\n",
    "                num_examples += label_mask.sum()\n",
    "\n",
    "                # Sample random noise for G\n",
    "                batch_z = np.random.normal(0, 1, size=(batch_size, z_size))\n",
    "\n",
    "                # Run optimizers\n",
    "                t1 = time.time()\n",
    "                _, _, correct = sess.run([net.d_opt, net.g_opt, net.masked_correct],\n",
    "                                         feed_dict={net.input_real: x, net.input_z: batch_z,\n",
    "                                                    net.y : y, net.label_mask : label_mask})\n",
    "                t2 = time.time()\n",
    "                num_correct += correct\n",
    "\n",
    "            sess.run([net.shrink_lr])\n",
    "            \n",
    "            \n",
    "            train_accuracy = num_correct / float(num_examples)\n",
    "            \n",
    "            print(\"\\t\\tClassifier train accuracy: \", train_accuracy)\n",
    "            \n",
    "            num_examples = 0\n",
    "            num_correct = 0\n",
    "            for x, y in dataset.batches(batch_size, which_set=\"test\"):\n",
    "                assert 'int' in str(y.dtype)\n",
    "                num_examples += x.shape[0]\n",
    "\n",
    "                correct, = sess.run([net.correct], feed_dict={net.input_real: x,\n",
    "                                                   net.y : y,\n",
    "                                                   net.drop_rate: 0.})\n",
    "                num_correct += correct\n",
    "            \n",
    "            test_accuracy = num_correct / float(num_examples)\n",
    "            print(\"\\t\\tClassifier test accuracy\", test_accuracy)\n",
    "            print(\"\\t\\tStep time: \", t2 - t1)\n",
    "            t2e = time.time()\n",
    "            print(\"\\t\\tEpoch time: \", t2e - t1e)\n",
    "            \n",
    "            \n",
    "            gen_samples = sess.run(\n",
    "                                   net.samples,\n",
    "                                   feed_dict={net.input_z: sample_z})\n",
    "            samples.append(gen_samples)\n",
    "            _ = view_samples(-1, samples, 5, 10, figsize=figsize)\n",
    "            plt.show()\n",
    "            \n",
    "            \n",
    "            # Save history of accuracies to view after training\n",
    "            train_accuracies.append(train_accuracy)\n",
    "            test_accuracies.append(test_accuracy)\n",
    "            \n",
    "\n",
    "        saver.save(sess, './checkpoints/generator.ckpt')\n",
    "\n",
    "    with open('samples.pkl', 'wb') as f:\n",
    "        pkl.dump(samples, f)\n",
    "    \n",
    "    return train_accuracies, test_accuracies, samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!mkdir checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "real_size = (32,32,3)\n",
    "z_size = 100\n",
    "learning_rate = 0.0003\n",
    "\n",
    "net = GAN(real_size, z_size, learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "dataset = Dataset(trainset, testset)\n",
    "\n",
    "batch_size = 128\n",
    "epochs = 25\n",
    "train_accuracies, test_accuracies, samples = train(net,\n",
    "                                                   dataset,\n",
    "                                                   epochs,\n",
    "                                                   batch_size,\n",
    "                                                   figsize=(10,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "plt.plot(train_accuracies, label='Train', alpha=0.5)\n",
    "plt.plot(test_accuracies, label='Test', alpha=0.5)\n",
    "plt.title(\"Accuracy\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When you run the fully implemented semi-supervised GAN, you should usually find that the test accuracy peaks at 69-71% and stays above 68% fairly consistently throughout the last several epochs of training.\n",
    "\n",
    "This is a little better than a [NIPS 2014 paper](https://arxiv.org/pdf/1406.5298.pdf) that got 64% accuracy on 1000-label SVHN with variational methods. However, we still lose something by not using all the labels. If you re-run with all the labels included (for example, by setting every entry of `label_mask` to 1), you should obtain over 80% accuracy using this architecture (and other architectures that take longer to run can do much better)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "_ = view_samples(-1, samples, 5, 10, figsize=(10,5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "!mkdir images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "for ii in range(len(samples)):\n",
    "    fig, ax = view_samples(ii, samples, 5, 10, figsize=(10,5))\n",
    "    fig.savefig('images/samples_{:03d}.png'.format(ii))\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "Congratulations! You now know how to train a semi-supervised GAN. This exercise is stripped down to make it run faster and to make it simpler to implement. In the original work by Tim Salimans and colleagues at OpenAI, a GAN using [more tricks and more runtime](https://arxiv.org/pdf/1606.03498.pdf) reaches about 92% accuracy (roughly 8% test error) using only 1,000 labeled examples, and an ensemble of 10 such models exceeds 94%."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
